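# 1. Run on the data in its original order.
# 2. Visit the points in a random order that is fixed for each run; repeat the run 2000 times.
# 3. Same as 2, but the update rule also applies a weight (learning rate).
# 4. Same as 2, but switched to pocket: data is split into training and test sets, each run
#    performs 50 updates, and the average error on the test set is reported.
# 5. Same as 4, but the w returned after the 50 updates is replaced by the plain PLA w.
# 6. Same as 4, but with 100 updates per run; check the result over 2000 runs.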

import random
import re
import copy
import pla_plot
# randomly generate an array of datas
# def randomData(num):


def readDataFrom(filename):
    """Load a whitespace-separated numeric table from a text file.

    Every non-blank line becomes one row of floats, kept in file order.
    Fields may be separated by any run of spaces and/or tabs. Blank lines
    are skipped, and a final line without a trailing newline is read in
    full (its last field is not dropped).

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of rows, each row being a list of floats.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to float.
    """
    result = []
    with open(filename, 'r') as f:
        for line in f:
            # str.split() with no argument collapses runs of whitespace and
            # ignores leading/trailing whitespace, so no empty fields appear.
            fields = line.split()
            if fields:
                result.append([float(field) for field in fields])
    return result
    
def pla(datas):
    """Run the perceptron learning algorithm, cycling through rows in order.

    Each row of ``datas`` holds the feature values followed by the label in
    the last column. The weight vector has one entry per feature plus a bias
    term stored at the end (``w[-1]``). A row counts as a mistake when the
    score is ``<= 0`` for a positive label or ``> 0`` for a negative label;
    rows whose label is 0 are never treated as mistakes.

    The scan continues from the row after the last mistake and stops once
    ``len(datas)`` consecutive rows (i.e. every row) are classified
    correctly by the current weights. On data that can never be fully
    classified the loop does not terminate.

    Args:
        datas: List of rows ``[x_1, ..., x_d, label]``.

    Returns:
        ``[w, run_times]`` where ``w`` is the final weight list and
        ``run_times`` is the number of weight updates plus one, or ``None``
        when ``datas`` has fewer than two rows or its first row is empty.
    """
    size = len(datas)
    if size <= 1:
        return
    dms = len(datas[0])
    if dms == 0:
        return

    w = [0 for x in range(0, dms)]
    # Starts at 1 so the result keeps the "updates + 1" convention.
    run_times = 1
    now = 0
    # Number of consecutive rows classified correctly since the last update.
    clean_streak = 0
    while clean_streak < size:
        row = datas[now]
        label = row[-1]
        p = 0
        for x in range(0, dms - 1):
            p += w[x] * row[x]
        p += w[-1]
        if (p <= 0 and label > 0) or (p > 0 and label < 0):
            # Update w (the bias w_0 is kept at the end of the list).
            for x in range(0, dms - 1):
                w[x] += label * row[x]
            w[-1] += label
            run_times += 1
            # The row just used is re-checked too: one update does not
            # guarantee that it is now classified correctly.
            clean_streak = 0
        else:
            clean_streak += 1
        now = (now + 1) % size

    return [w, run_times]

def randomIndex(n):
    """Return the integers ``0 .. n-1`` in random order.

    The result always contains each index exactly once, so it can be used
    to visit every row of a dataset once in shuffled order.

    Args:
        n: Number of indices to generate.

    Returns:
        A new list holding a random permutation of ``range(n)`` (empty when
        ``n <= 0``).
    """
    index = list(range(0, n))
    # random.shuffle swaps safely even when an element is swapped with
    # itself; an add/subtract swap would zero the element in that case.
    random.shuffle(index)
    return index

  
def plaImproved(datas, n=1):
    """Run the perceptron learning algorithm over a randomly ordered cycle.

    The visiting order is drawn once with ``randomIndex`` and then repeated
    cyclically. Each row of ``datas`` holds the feature values followed by
    the label in the last column; the bias weight is stored at the end of
    the weight list. A row counts as a mistake when the score is ``<= 0``
    for a positive label or ``> 0`` for a negative label; rows whose label
    is 0 are never treated as mistakes.

    The scan stops once ``len(datas)`` consecutive positions of the visiting
    order are classified correctly by the current weights. On data that can
    never be fully classified the loop does not terminate.

    Args:
        datas: List of rows ``[x_1, ..., x_d, label]``.
        n: Step size multiplied into every weight update.

    Returns:
        ``[para, run_times]`` where ``para`` is the final weight list and
        ``run_times`` is the number of weight updates plus one, or ``None``
        when ``datas`` has fewer than two rows or its first row is empty.
    """
    size = len(datas)
    if size <= 1:
        return
    dms = len(datas[0])
    if dms == 0:
        return

    para = [0 for x in range(0, dms)]
    # Starts at 1 so the result keeps the "updates + 1" convention.
    run_times = 1
    index = randomIndex(size)
    i = 0
    # Number of consecutive positions classified correctly since the last update.
    clean_streak = 0
    while clean_streak < size:
        row = datas[index[i]]
        label = row[-1]
        p = 0
        for x in range(0, dms - 1):
            p += para[x] * row[x]
        p += para[-1]
        if (p <= 0 and label > 0) or (p > 0 and label < 0):
            # Use this mistake to update the parameters.
            for x in range(0, dms - 1):
                para[x] = para[x] + n * label * row[x]
            para[-1] += n * label
            run_times += 1
            # The row just used is re-checked too: one update does not
            # guarantee that it is now classified correctly.
            clean_streak = 0
        else:
            clean_streak += 1
        i = (i + 1) % size

    return [para, run_times]
  
def computeER(w, datas):
    """Measure how well the weights ``w`` classify ``datas``.

    Each row holds the feature values followed by the label in the last
    column; the bias weight is ``w[-1]``. A row is an error when the score
    is ``<= 0`` for a positive label or ``> 0`` for a negative label; rows
    whose label is 0 never count as errors.

    Args:
        w: Weight list, one entry per feature plus the bias at the end.
        datas: List of rows ``[x_1, ..., x_d, label]``.

    Returns:
        The fraction of rows that are misclassified, or ``None`` when
        ``datas`` has fewer than two rows or its first row is empty.
    """
    total = len(datas)
    if total <= 1:
        return None
    width = len(datas[0])
    if width == 0:
        return None
    feature_count = width - 1

    def is_wrong(row):
        # Accumulate in a plain loop so the float additions happen in a
        # fixed left-to-right order.
        score = 0
        for k in range(feature_count):
            score += w[k] * row[k]
        score += w[-1]
        target = row[-1]
        if score > 0:
            return target < 0
        return target > 0

    wrong = 0
    for row in datas:
        if is_wrong(row):
            wrong += 1
    return wrong / total
    
  
def pocket(datas, max_time=50, greedy=1):
    """Run the pocket variant of the perceptron learning algorithm.

    On every iteration the rows are scanned in a fresh random order (from
    ``randomIndex``) and the first misclassified row found is used to update
    the working weights. When ``greedy`` is non-zero, the updated weights
    are scored on all of ``datas`` and kept "in the pocket" whenever they
    make strictly fewer mistakes than the best weights seen so far.

    Each row holds the feature values followed by the label in the last
    column; the bias weight is stored at the end of the weight list. A row
    is a mistake when the score is ``<= 0`` for a positive label or ``> 0``
    for a negative label; rows whose label is 0 are never mistakes.

    Args:
        datas: List of rows ``[x_1, ..., x_d, label]``.
        max_time: Maximum number of weight updates.
        greedy: Non-zero to return the best (pocket) weights seen so far;
            0 to return the working weights after the last update.

    Returns:
        ``[w, run_times]`` where ``w`` is the selected weight list and
        ``run_times`` is the number of scan iterations performed, or
        ``None`` when ``datas`` has fewer than two rows or its first row is
        empty.
    """
    size = len(datas)
    if size <= 1:
        return
    dms = len(datas[0])
    if dms == 0:
        return

    def is_mistake(weights, row):
        # Score the row with the bias term added last.
        p = 0
        for x in range(0, dms - 1):
            p += weights[x] * row[x]
        p += weights[-1]
        return (p <= 0 and row[-1] > 0) or (p > 0 and row[-1] < 0)

    use_pocket = greedy != 0
    w = [0 for x in range(0, dms)]       # best weights found so far
    new_w = [0 for x in range(0, dms)]   # working weights
    last_error = size                    # mistake count credited to w
    run_times = 0

    # Checking the cap before each scan limits the run to exactly
    # max_time updates.
    while run_times < max_time:
        run_times += 1

        err_i = -1
        for i in randomIndex(size):
            if is_mistake(new_w, datas[i]):
                err_i = i
                break
        if err_i == -1:
            # No mistake left: the working weights classify every row.
            break

        # Use this mistake to update the working weights.
        row = datas[err_i]
        for x in range(0, dms - 1):
            new_w[x] += row[-1] * row[x]
        new_w[-1] += row[-1]

        if use_pocket:
            # Count mistakes over the rows themselves so that every row is
            # scored exactly once, independent of the visiting order.
            new_error = 0
            for candidate in datas:
                if is_mistake(new_w, candidate):
                    new_error += 1
            if new_error < last_error:
                # Copy, otherwise w would just alias the working list.
                w = list(new_w)
                last_error = new_error

    if use_pocket:
        return [w, run_times]
    return [new_w, run_times]
                
    
    
if __name__ == "__main__":
    # Number of independent randomized runs averaged for questions 16-20.
    REPEATS = 2000

    # Question 15: PLA visiting the rows in file order.
    train_15 = readDataFrom('hw1_15_train.dat')
    print("15题: 修正次数：", pla(train_15)[1])

    # Question 16: PLA with a random visiting order, averaged over the runs.
    total_runs = 0
    for _ in range(0, REPEATS):
        total_runs += plaImproved(train_15)[1]
    print("16: 2000次平均循环次数： ", total_runs / REPEATS)

    # Question 17: same as 16 with the update scaled by 0.5.
    total_runs = 0
    for _ in range(0, REPEATS):
        total_runs += plaImproved(train_15, 0.5)[1]
    print("17: 循环次数", total_runs / REPEATS)

    train_18 = readDataFrom('hw1_18_train.dat')
    test_18 = readDataFrom('hw1_18_test.dat')

    # Question 18: pocket weights with the default update cap, scored on the test set.
    total_error = 0
    for _ in range(0, REPEATS):
        weights = pocket(train_18)[0]
        total_error += computeER(weights, test_18)
    print("18: 平均误差率： ", total_error / REPEATS)

    # Question 19: the working (non-pocket) weights, scored on the test set
    # like questions 18 and 20 so the three results are comparable.
    total_error = 0
    for _ in range(0, REPEATS):
        weights = pocket(train_18, greedy=0)[0]
        total_error += computeER(weights, test_18)
    print("19: 平均误差率： ", total_error / REPEATS)

    # Question 20: pocket weights with the update cap raised to 100.
    total_error = 0
    for _ in range(0, REPEATS):
        weights = pocket(train_18, 100)[0]
        total_error += computeER(weights, test_18)
    print("20: 平均误差率： ", total_error / REPEATS)